package com.dny.asmtop;

import jdk.internal.org.objectweb.asm.FieldVisitor;
import jdk.internal.org.objectweb.asm.Type;
import jdk.internal.org.objectweb.asm.commons.GeneratorAdapter;
import jdk.internal.org.objectweb.asm.commons.Method;

import java.io.FileOutputStream;
import java.io.IOException;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.nio.file.Path;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

import static jdk.internal.org.objectweb.asm.Opcodes.*;
import static jdk.internal.org.objectweb.asm.Type.getInternalName;
import static jdk.internal.org.objectweb.asm.Type.getType;
import static jdk.internal.org.objectweb.asm.commons.Method.getMethod;

/**
 * Created by jlutt on 2018-01-15.
 * <br>
 * Class builder: generates a new class with ASM at runtime. The generated class extends
 * {@code mainClass} (or {@code java.lang.Object} when {@code mainClass} is an interface, which is
 * then implemented instead), implements every type in {@code otherClasses}, and declares the
 * configured public instance fields, public static fields and public methods.
 *
 * <p>The configuration methods of this builder are not synchronized; {@link #build()} synchronizes
 * on the class loader while looking up or generating the class.
 *
 * @param <T> the main type of the generated class
 * @author jlutt
 */
public final class ClassBuilder<T> {

  /**
   * Base name used when no class name was configured. It lives in this class's package and gets a
   * counter suffix appended.
   */
  private static final String DEFAULT_CLASS_NAME = ClassBuilder.class.getPackage().getName() + ".Class";

  /**
   * Counter appended to {@link #DEFAULT_CLASS_NAME} so that generated default names do not repeat.
   */
  private static final AtomicInteger COUNTER = new AtomicInteger();

  /**
   * Class loader used to look up previously generated classes by key and to define new ones.
   */
  private final ASMClassLoader classLoader;

  /**
   * Main type: the superclass, or the first implemented interface when it is an interface.
   */
  private final Class<T> mainClass;

  /**
   * Additional interfaces implemented by the generated class.
   */
  private final List<Class<?>> otherClasses;

  /**
   * Name of the generated class; {@code null} means a default name is generated.
   */
  private String className;

  /**
   * Optional writer for class-level annotations.
   */
  private ClassAnnotationWrite classAnnotationWrite;

  /**
   * Public instance fields: name to type, in insertion order.
   */
  private final Map<String, Class<?>> fields = new LinkedHashMap<>();

  /**
   * Annotation writers for instance fields, keyed by field name.
   */
  private final Map<String, FieldAnnotationWrite> fieldAnnotations = new LinkedHashMap<>();

  /**
   * Public static fields: name to declared type.
   */
  private final Map<String, Class<?>> staticFields = new LinkedHashMap<>();

  /**
   * Values assigned reflectively to the static fields after the class is loaded: name to value.
   */
  private final Map<String, Object> staticConstants = new LinkedHashMap<>();

  /**
   * Instance methods: signature to body.
   */
  private final Map<Method, Command> methods = new LinkedHashMap<>();

  /**
   * Annotation writers for instance methods, keyed by signature.
   */
  private final Map<Method, MethodAnnotationWrite> methodAnnotations = new LinkedHashMap<>();

  /**
   * Static methods: signature to body. Nothing in this class adds entries yet and no bytecode is
   * generated for them; the map is only passed to {@link MethodContext} and included in the cache key.
   */
  private final Map<Method, Command> staticMethods = new LinkedHashMap<>();

  /**
   * Directory where the generated .class file is additionally written (handy for tests); may be null.
   */
  private Path classSavePath;

  private ClassBuilder(ASMClassLoader classLoader, Class<T> type) {
    this(classLoader, type, Collections.<Class<?>>emptyList());
  }

  private ClassBuilder(ASMClassLoader classLoader, Class<T> mainType, List<Class<?>> types) {
    this.classLoader = classLoader;
    this.mainClass = mainType;
    // Copy so later changes to the caller's list cannot alter the generated class or its cache key.
    this.otherClasses = new ArrayList<>(Objects.requireNonNull(types, "types"));
  }

  /**
   * Creates a builder whose generated class extends or implements {@code type} only.
   *
   * @param classLoader loader used to cache and define the generated class
   * @param type        main type (superclass, or interface to implement)
   * @param <T>         main type
   * @return a new builder
   */
  public static <T> ClassBuilder<T> create(ASMClassLoader classLoader, Class<T> type) {
    return new ClassBuilder<>(classLoader, type);
  }

  /**
   * Creates a builder whose generated class extends or implements {@code mainType} and also
   * implements every interface in {@code types}.
   *
   * @param classLoader loader used to cache and define the generated class
   * @param mainType    main type (superclass, or interface to implement)
   * @param types       additional interfaces; the list is copied
   * @param <T>         main type
   * @return a new builder
   */
  public static <T> ClassBuilder<T> create(ASMClassLoader classLoader, Class<T> mainType, List<Class<?>> types) {
    return new ClassBuilder<>(classLoader, mainType, types);
  }

  /**
   * Sets the fully qualified name of the generated class.
   *
   * @param className dotted class name, or {@code null} to use a generated default name
   * @return this builder
   */
  public ClassBuilder<T> withClassName(String className) {
    this.className = className;
    return this;
  }

  /**
   * Sets the writer that emits class-level annotations.
   *
   * @param annotationWrite annotation writer, or {@code null} for none
   * @return this builder
   */
  public ClassBuilder<T> withAnnotation(ClassAnnotationWrite annotationWrite) {
    this.classAnnotationWrite = annotationWrite;
    return this;
  }

  /**
   * Sets the directory where the generated bytecode is also saved as a .class file.
   *
   * @param classSavePath directory for the class file, or {@code null} to skip saving
   * @return this builder
   */
  public ClassBuilder<T> savePath(Path classSavePath) {
    this.classSavePath = classSavePath;
    return this;
  }

  /**
   * Adds a public instance field.
   *
   * @param fieldName  field name
   * @param fieldClass field type
   * @return this builder
   */
  public ClassBuilder<T> addField(String fieldName, Class<?> fieldClass) {
    this.fields.put(fieldName, fieldClass);
    return this;
  }

  /**
   * Adds a public instance field together with a writer for its annotations.
   *
   * @param fieldName       field name
   * @param fieldClass      field type
   * @param annotationWrite field annotation writer; {@code null} writes no annotations
   * @return this builder
   */
  public ClassBuilder<T> addField(String fieldName, Class<?> fieldClass, FieldAnnotationWrite annotationWrite) {
    this.fields.put(fieldName, fieldClass);
    this.fieldAnnotations.put(fieldName, annotationWrite);
    return this;
  }

  /**
   * Adds a public static field. Its value is assigned reflectively by {@link #build()}.
   *
   * @param fieldName  field name
   * @param fieldClass declared field type
   * @param fieldValue value assigned after the class is loaded
   * @return this builder
   */
  public ClassBuilder<T> addStaticField(String fieldName, Class<?> fieldClass, Object fieldValue) {
    this.staticFields.put(fieldName, fieldClass);
    this.staticConstants.put(fieldName, fieldValue);
    return this;
  }

  /**
   * Adds a public instance method.
   *
   * @param method  method signature
   * @param command method body
   * @return this builder
   */
  public ClassBuilder<T> addMethod(Method method, Command command) {
    return addMethod(method, command, null);
  }

  /**
   * Adds a public instance method with an optional annotation writer.
   *
   * @param method  method signature
   * @param command method body
   * @param aw      method annotation writer; {@code null} keeps any previously registered writer
   * @return this builder
   */
  public ClassBuilder<T> addMethod(Method method, Command command, MethodAnnotationWrite aw) {
    this.methods.put(method, command);
    if (aw != null) {
      this.methodAnnotations.put(method, aw);
    }
    return this;
  }

  /**
   * Adds a public instance method identified by name or by full ASM method descriptor.
   *
   * @param methodName a descriptor such as {@code "int foo (int)"}, or a bare method name to look up
   * @param command    method body
   * @return this builder
   */
  public ClassBuilder<T> addMethod(String methodName, Command command) {
    return addMethod(methodName, command, null);
  }

  /**
   * Adds a public instance method identified by name or by full ASM method descriptor.
   *
   * <p>When {@code methodName} contains {@code '('} it is parsed with {@link Method#getMethod(String)}.
   * Otherwise the name is looked up on {@code Object}, the main type and the other types (public and
   * declared methods); all matches must share one signature.
   *
   * @param methodName a descriptor such as {@code "int foo (int)"}, or a bare method name to look up
   * @param command    method body
   * @param aw         method annotation writer, may be {@code null}
   * @return this builder
   * @throws IllegalArgumentException if the name matches methods with different signatures, or none
   */
  public ClassBuilder<T> addMethod(String methodName, Command command, MethodAnnotationWrite aw) {
    if (methodName.contains("(")) {
      // Pass aw through as well so the annotation writer is not lost on this path.
      return addMethod(Method.getMethod(methodName), command, aw);
    }

    List<java.lang.reflect.Method> candidates = new ArrayList<>();
    candidates.addAll(Arrays.asList(Object.class.getMethods()));
    candidates.addAll(Arrays.asList(mainClass.getMethods()));
    candidates.addAll(Arrays.asList(mainClass.getDeclaredMethods()));
    for (Class<?> type : otherClasses) {
      candidates.addAll(Arrays.asList(type.getMethods()));
      candidates.addAll(Arrays.asList(type.getDeclaredMethods()));
    }

    Method foundMethod = null;
    for (java.lang.reflect.Method m : candidates) {
      if (!m.getName().equals(methodName)) {
        continue;
      }
      Method method = getMethod(m);
      // The same method is usually reported several times; only differing signatures are a conflict.
      if (foundMethod != null && !method.equals(foundMethod)) {
        throw new IllegalArgumentException("方法 " + method + " 和 " + foundMethod + " 冲突");
      }
      foundMethod = method;
    }

    if (foundMethod == null) {
      throw new IllegalArgumentException("无法找到方法:" + methodName);
    }
    return addMethod(foundMethod, command, aw);
  }

  /**
   * Adds a public instance method from its name, return type and parameter types.
   *
   * @param methodName    method name
   * @param returnClass   return type
   * @param argumentTypes parameter types, in order
   * @param command       method body
   * @return this builder
   */
  public ClassBuilder<T> addMethod(String methodName,
                                   Class<?> returnClass,
                                   List<? extends Class<?>> argumentTypes,
                                   Command command) {
    return addMethod(methodName, returnClass, argumentTypes, command, null);
  }

  /**
   * Adds a public instance method from its name, return type and parameter types.
   *
   * @param methodName    method name
   * @param returnClass   return type
   * @param argumentTypes parameter types, in order
   * @param command       method body
   * @param aw            method annotation writer, may be {@code null}
   * @return this builder
   */
  public ClassBuilder<T> addMethod(String methodName,
                                   Class<?> returnClass,
                                   List<? extends Class<?>> argumentTypes,
                                   Command command,
                                   MethodAnnotationWrite aw) {
    Type[] types = new Type[argumentTypes.size()];
    for (int i = 0; i < types.length; i++) {
      types[i] = getType(argumentTypes.get(i));
    }
    return addMethod(new Method(methodName, getType(returnClass), types), command, aw);
  }

  /**
   * Generates the bytecode for the configured class and defines it in {@link #classLoader}.
   *
   * @param key cache key under which the new class is registered
   * @return the newly defined class
   */
  @SuppressWarnings("unchecked")
  private Class<T> loadNewClass(ClassKey<T> key) {
    ASMClassWriter cw = ASMClassWriter.create(classLoader);

    String newClassName = className != null ? className : DEFAULT_CLASS_NAME + COUNTER.incrementAndGet();
    Type classType = getType('L' + newClassName.replace('.', '/') + ';');

    visitClassHeader(cw, classType);
    if (classAnnotationWrite != null) {
      classAnnotationWrite.write(cw);
    }
    writeDefaultConstructor(cw);
    writeInstanceFields(cw);
    // Static methods: nothing in this class populates staticMethods, so no code is generated for them.
    writeMethods(cw, classType);
    writeStaticFields(cw);
    cw.visitEnd();

    byte[] bytes = cw.toByteArray();
    saveClassFile(newClassName, bytes);

    return (Class<T>) classLoader.loadClass(newClassName, key, bytes);
  }

  /** Writes the class header: public final, Java 8 class file, superclass and interfaces. */
  private void visitClassHeader(ASMClassWriter cw, Type classType) {
    List<String> interfaceNames = new ArrayList<>();
    String superName;
    if (mainClass.isInterface()) {
      // An interface main type is implemented (listed first); the superclass is then Object.
      superName = "java/lang/Object";
      interfaceNames.add(getInternalName(mainClass));
    } else {
      superName = getInternalName(mainClass);
    }
    for (Class<?> type : otherClasses) {
      interfaceNames.add(getInternalName(type));
    }
    cw.visit(V1_8, ACC_PUBLIC + ACC_FINAL + ACC_SUPER,
        classType.getInternalName(),
        null,
        superName,
        interfaceNames.toArray(new String[0]));
  }

  /** Writes a public no-arg constructor that only calls the superclass no-arg constructor. */
  private void writeDefaultConstructor(ASMClassWriter cw) {
    Method m = getMethod("void <init> ()");
    GeneratorAdapter g = new GeneratorAdapter(ACC_PUBLIC, m, null, null, cw);
    g.loadThis();
    Class<?> superClass = mainClass.isInterface() ? Object.class : mainClass;
    g.invokeConstructor(getType(superClass), m);
    g.returnValue();
    g.endMethod();
  }

  /** Declares every configured public instance field and writes its annotations, if any. */
  private void writeInstanceFields(ASMClassWriter cw) {
    for (Map.Entry<String, Class<?>> entry : fields.entrySet()) {
      String fieldName = entry.getKey();
      FieldVisitor fv = cw.visitField(ACC_PUBLIC, fieldName, getType(entry.getValue()).getDescriptor(), null, null);
      FieldAnnotationWrite annotationWrite = fieldAnnotations.get(fieldName);
      if (annotationWrite != null) {
        annotationWrite.write(fv);
      }
    }
  }

  /** Generates every configured public instance method, wrapping any failure in a RuntimeException. */
  private void writeMethods(ASMClassWriter cw, Type classType) {
    for (Map.Entry<Method, Command> entry : methods.entrySet()) {
      Method m = entry.getKey();
      try {
        GeneratorAdapter mv = new GeneratorAdapter(ACC_PUBLIC, m, null, null, cw);

        MethodAnnotationWrite annotationWrite = methodAnnotations.get(m);
        if (annotationWrite != null) {
          annotationWrite.write(mv);
        }

        MethodContext ctx = new MethodContext(classLoader, mv, classType,
            mainClass, otherClasses, fields, staticConstants, m.getArgumentTypes(), m, methods, staticMethods);

        ASMMethodUtils.generatorAndCast(ctx, entry.getValue(), m.getReturnType());
        mv.returnValue();
        mv.endMethod();
      } catch (Exception e) {
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Declares every public static field exactly once with its declared type. The values are not
   * written here; {@link #build()} assigns them reflectively. (Declaring them a second time from the
   * values' runtime classes would produce duplicate or conflicting field entries.)
   */
  private void writeStaticFields(ASMClassWriter cw) {
    for (Map.Entry<String, Class<?>> entry : staticFields.entrySet()) {
      cw.visitField(ACC_PUBLIC + ACC_STATIC, entry.getKey(), getType(entry.getValue()).getDescriptor(), null, null);
    }
  }

  /** Writes the already-generated bytes to {@code classSavePath/<className>.class} when a path is set. */
  private void saveClassFile(String newClassName, byte[] bytes) {
    if (classSavePath == null) {
      return;
    }
    try (FileOutputStream fos = new FileOutputStream(classSavePath.resolve(newClassName + ".class").toFile())) {
      fos.write(bytes);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Returns the generated class. It is taken from the class loader's cache when an equal
   * configuration was built before; otherwise it is generated, and its static fields are assigned.
   *
   * @return the generated (or cached) class
   */
  @SuppressWarnings("unchecked")
  public Class<T> build() {
    // Lookup and generation must run under the class loader's lock.
    synchronized (classLoader) {
      ClassKey<T> key = new ClassKey<>(mainClass, otherClasses, fields, methods, staticMethods,
          staticFields, staticConstants);
      Class<?> cachedClass = classLoader.loadClassByKey(key);
      if (cachedClass != null) {
        return (Class<T>) cachedClass;
      }

      Class<T> newClass = loadNewClass(key);
      assignStaticValues(newClass);
      return newClass;
    }
  }

  /** Assigns the configured static field values on a freshly generated class. */
  private void assignStaticValues(Class<T> newClass) {
    for (Map.Entry<String, Object> entry : staticConstants.entrySet()) {
      try {
        Field field = newClass.getField(entry.getKey());
        field.set(null, entry.getValue());
      } catch (NoSuchFieldException | IllegalAccessException e) {
        throw new AssertionError(e);
      }
    }
  }

  /**
   * Builds the class and instantiates it with its public no-arg constructor.
   *
   * @return a new instance
   * @throws RuntimeException wrapping reflective failures; unchecked exceptions thrown by the
   *                          constructor are rethrown as-is
   */
  public T buildInstance() {
    try {
      return build().getConstructor().newInstance();
    } catch (InvocationTargetException e) {
      // Class#newInstance (used previously, deprecated since Java 9) let constructor exceptions
      // propagate unwrapped; keep that for unchecked ones.
      Throwable cause = e.getCause();
      if (cause instanceof RuntimeException) {
        throw (RuntimeException) cause;
      }
      if (cause instanceof Error) {
        throw (Error) cause;
      }
      throw new RuntimeException(cause);
    } catch (InstantiationException | IllegalAccessException | NoSuchMethodException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Builds the class and instantiates it with the public constructor whose parameter types exactly
   * match the runtime classes of {@code constructorParameters}. Note that the generated class only
   * declares a public no-arg constructor.
   *
   * @param constructorParameters constructor arguments; none may be {@code null}
   * @return a new instance
   * @throws NullPointerException if an argument is {@code null} (its type cannot be inferred)
   */
  public T buildInstance(Object... constructorParameters) {
    Class<?>[] constructorParameterTypes = new Class<?>[constructorParameters.length];
    for (int i = 0; i < constructorParameters.length; i++) {
      Object parameter = Objects.requireNonNull(constructorParameters[i],
          "constructorParameters[" + i + "] is null; use buildInstance(Class[], Object[])");
      constructorParameterTypes[i] = parameter.getClass();
    }
    return buildInstance(constructorParameterTypes, constructorParameters);
  }

  /**
   * Builds the class and instantiates it with the public constructor of the given parameter types.
   *
   * @param constructorParameterTypes constructor parameter types
   * @param constructorParameters     constructor arguments
   * @return a new instance
   * @throws RuntimeException wrapping any reflective failure
   */
  public T buildInstance(Class[] constructorParameterTypes, Object[] constructorParameters) {
    try {
      return build().getConstructor(constructorParameterTypes).newInstance(constructorParameters);
    } catch (InstantiationException | IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
      throw new RuntimeException(e);
    }
  }


  /**
   * Cache key that identifies the shape of a generated class.
   *
   * <p>All collections are copied into unmodifiable snapshots, so mutating a builder after
   * {@code build()} cannot change the hash code of a key that is already cached. Class names and
   * annotation writers are not part of the key.
   */
  static class ClassKey<T> {
    private final Class<T> mainClass;
    private final List<Class<?>> otherClasses;
    private final Map<String, Class<?>> fields;
    private final Map<Method, Command> expressionMap;
    private final Map<Method, Command> expressionStaticMap;
    private final Map<String, Class<?>> staticFields;
    private final Map<String, Object> staticConstants;

    /**
     * Creates a key for a class without static fields.
     */
    public ClassKey(Class<T> mainClass,
                    List<Class<?>> otherClasses,
                    Map<String, Class<?>> fields,
                    Map<Method, Command> expressionMap,
                    Map<Method, Command> expressionStaticMap) {
      this(mainClass, otherClasses, fields, expressionMap, expressionStaticMap,
          Collections.<String, Class<?>>emptyMap(), Collections.<String, Object>emptyMap());
    }

    /**
     * Creates a key that also distinguishes static field types and their assigned values, so that
     * builders with different static constants never share one generated class.
     */
    public ClassKey(Class<T> mainClass,
                    List<Class<?>> otherClasses,
                    Map<String, Class<?>> fields,
                    Map<Method, Command> expressionMap,
                    Map<Method, Command> expressionStaticMap,
                    Map<String, Class<?>> staticFields,
                    Map<String, Object> staticConstants) {
      this.mainClass = mainClass;
      this.otherClasses = copyList(otherClasses);
      this.fields = copyMap(fields);
      this.expressionMap = copyMap(expressionMap);
      this.expressionStaticMap = copyMap(expressionStaticMap);
      this.staticFields = copyMap(staticFields);
      this.staticConstants = copyMap(staticConstants);
    }

    /** Unmodifiable snapshot of {@code list}; {@code null} stays {@code null}. */
    private static <E> List<E> copyList(List<E> list) {
      return list == null ? null : Collections.unmodifiableList(new ArrayList<>(list));
    }

    /** Unmodifiable, order-preserving snapshot of {@code map}; {@code null} stays {@code null}. */
    private static <K, V> Map<K, V> copyMap(Map<K, V> map) {
      return map == null ? null : Collections.unmodifiableMap(new LinkedHashMap<>(map));
    }

    public Class<T> getMainClass() {
      return mainClass;
    }

    public List<Class<?>> getOtherClasses() {
      return otherClasses;
    }

    @Override
    public boolean equals(Object o) {
      if (this == o) {
        return true;
      }
      if (o == null || getClass() != o.getClass()) {
        return false;
      }
      ClassKey<?> classKey = (ClassKey<?>) o;
      return Objects.equals(mainClass, classKey.mainClass) &&
          Objects.equals(otherClasses, classKey.otherClasses) &&
          Objects.equals(fields, classKey.fields) &&
          Objects.equals(expressionMap, classKey.expressionMap) &&
          Objects.equals(expressionStaticMap, classKey.expressionStaticMap) &&
          Objects.equals(staticFields, classKey.staticFields) &&
          Objects.equals(staticConstants, classKey.staticConstants);
    }

    @Override
    public int hashCode() {
      return Objects.hash(mainClass, otherClasses, fields, expressionMap, expressionStaticMap,
          staticFields, staticConstants);
    }
  }
}
